{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "025a0e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split  # 这里是引用了交叉验证\n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a40c0adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training and test splits from their CSV files\n",
    "# (paths are relative to the current working directory).\n",
    "train_data, test_data = (pd.read_csv(csv_path) for csv_path in ('iris_train.data', 'iris_test.data'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d58714a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>花萼长度</th>\n",
       "      <th>花萼宽度</th>\n",
       "      <th>花瓣长度</th>\n",
       "      <th>花瓣宽度</th>\n",
       "      <th>种类</th>\n",
       "      <th>种类编码</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   花萼长度  花萼宽度  花瓣长度  花瓣宽度           种类  种类编码\n",
       "0   5.1   3.5   1.4   0.2  Iris-setosa     0\n",
       "1   4.9   3.0   1.4   0.2  Iris-setosa     0\n",
       "2   4.7   3.2   1.3   0.2  Iris-setosa     0\n",
       "3   4.6   3.1   1.5   0.2  Iris-setosa     0\n",
       "4   5.0   3.6   1.4   0.2  Iris-setosa     0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the training frame.\n",
    "# NOTE(review): the saved output shows the 种类编码 column, which is only added by the\n",
    "# following cell; on a fresh top-to-bottom run this preview will not include it.\n",
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4117ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the species names as integer class codes (0/1/2) for the classifier.\n",
    "# A single shared mapping keeps the train and test encodings in agreement.\n",
    "SPECIES_CODES = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}\n",
    "\n",
    "for frame in (train_data, test_data):\n",
    "    frame['种类编码'] = frame['种类'].map(SPECIES_CODES)\n",
    "    # Series.map() yields NaN for any label missing from the dict; fail here with a\n",
    "    # clear message instead of passing NaN targets on to the model.\n",
    "    assert frame['种类编码'].notna().all(), 'unmapped species label in 种类 column'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "61e7866f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[5.1 3.5 1.4 0.2]\n",
      " [4.9 3.  1.4 0.2]\n",
      " [4.7 3.2 1.3 0.2]\n",
      " [4.6 3.1 1.5 0.2]\n",
      " [5.  3.6 1.4 0.2]\n",
      " [5.4 3.9 1.7 0.4]\n",
      " [4.6 3.4 1.4 0.3]\n",
      " [5.  3.4 1.5 0.2]\n",
      " [4.4 2.9 1.4 0.2]\n",
      " [4.9 3.1 1.5 0.1]\n",
      " [5.4 3.7 1.5 0.2]\n",
      " [4.8 3.4 1.6 0.2]\n",
      " [4.8 3.  1.4 0.1]\n",
      " [4.3 3.  1.1 0.1]\n",
      " [5.8 4.  1.2 0.2]\n",
      " [5.7 4.4 1.5 0.4]\n",
      " [5.4 3.9 1.3 0.4]\n",
      " [5.1 3.5 1.4 0.3]\n",
      " [5.7 3.8 1.7 0.3]\n",
      " [5.1 3.8 1.5 0.3]\n",
      " [7.  3.2 4.7 1.4]\n",
      " [6.4 3.2 4.5 1.5]\n",
      " [6.9 3.1 4.9 1.5]\n",
      " [5.5 2.3 4.  1.3]\n",
      " [6.5 2.8 4.6 1.5]\n",
      " [5.7 2.8 4.5 1.3]\n",
      " [6.3 3.3 4.7 1.6]\n",
      " [4.9 2.4 3.3 1. ]\n",
      " [6.6 2.9 4.6 1.3]\n",
      " [5.2 2.7 3.9 1.4]\n",
      " [5.  2.  3.5 1. ]\n",
      " [5.9 3.  4.2 1.5]\n",
      " [6.  2.2 4.  1. ]\n",
      " [6.1 2.9 4.7 1.4]\n",
      " [5.6 2.9 3.6 1.3]\n",
      " [6.7 3.1 4.4 1.4]\n",
      " [5.6 3.  4.5 1.5]\n",
      " [5.8 2.7 4.1 1. ]\n",
      " [6.2 2.2 4.5 1.5]\n",
      " [5.6 2.5 3.9 1.1]\n",
      " [6.3 3.3 6.  2.5]\n",
      " [5.8 2.7 5.1 1.9]\n",
      " [7.1 3.  5.9 2.1]\n",
      " [6.3 2.9 5.6 1.8]\n",
      " [6.5 3.  5.8 2.2]\n",
      " [7.6 3.  6.6 2.1]\n",
      " [4.9 2.5 4.5 1.7]\n",
      " [7.3 2.9 6.3 1.8]\n",
      " [6.7 2.5 5.8 1.8]\n",
      " [7.2 3.6 6.1 2.5]\n",
      " [6.5 3.2 5.1 2. ]\n",
      " [6.4 2.7 5.3 1.9]\n",
      " [6.8 3.  5.5 2.1]\n",
      " [5.7 2.5 5.  2. ]\n",
      " [5.8 2.8 5.1 2.4]\n",
      " [6.4 3.2 5.3 2.3]\n",
      " [6.5 3.  5.5 1.8]\n",
      " [7.7 3.8 6.7 2.2]\n",
      " [7.7 2.6 6.9 2.3]]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
      " 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]\n",
      "[[5.4 3.4 1.7 0.2]\n",
      " [5.1 3.7 1.5 0.4]\n",
      " [4.6 3.6 1.  0.2]\n",
      " [5.1 3.3 1.7 0.5]\n",
      " [4.8 3.4 1.9 0.2]\n",
      " [5.  3.  1.6 0.2]\n",
      " [5.  3.4 1.6 0.4]\n",
      " [5.2 3.5 1.5 0.2]\n",
      " [5.2 3.4 1.4 0.2]\n",
      " [5.9 3.2 4.8 1.8]\n",
      " [6.1 2.8 4.  1.3]\n",
      " [6.3 2.5 4.9 1.5]\n",
      " [6.1 2.8 4.7 1.2]\n",
      " [6.4 2.9 4.3 1.3]\n",
      " [6.6 3.  4.4 1.4]\n",
      " [6.8 2.8 4.8 1.4]\n",
      " [6.7 3.  5.  1.7]\n",
      " [6.  2.9 4.5 1.5]\n",
      " [5.7 2.6 3.5 1. ]\n",
      " [5.5 2.4 3.8 1.1]\n",
      " [6.  2.2 5.  1.5]\n",
      " [6.9 3.2 5.7 2.3]\n",
      " [5.6 2.8 4.9 2. ]\n",
      " [7.7 2.8 6.7 2. ]\n",
      " [6.3 2.7 4.9 1.8]\n",
      " [6.7 3.3 5.7 2.1]\n",
      " [7.2 3.2 6.  1.8]\n",
      " [6.2 2.8 4.8 1.8]\n",
      " [6.1 3.  4.9 1.8]\n",
      " [6.4 2.8 5.6 2.1]\n",
      " [7.2 3.  5.8 1.6]\n",
      " [7.4 2.8 6.1 1.9]]\n",
      "[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]\n",
      "X_train.shape=(59, 4)\n",
      " y_train.shape =(59,)\n",
      " X_test.shape=(32, 4)\n",
      " y_test.shape=(32,)\n"
     ]
    }
   ],
   "source": [
    "# Features are the first four columns (the sepal/petal measurements); the target is\n",
    "# the integer class code stored in the 种类编码 column. The label is selected by name\n",
    "# so the extraction does not depend on the position of that column in the frame.\n",
    "x_train = train_data.iloc[:, :4].to_numpy()\n",
    "y_train = train_data['种类编码'].to_numpy()\n",
    "x_test = test_data.iloc[:, :4].to_numpy()\n",
    "y_test = test_data['种类编码'].to_numpy()\n",
    "\n",
    "# Report shapes only; printing the full arrays floods the notebook output.\n",
    "print('X_train.shape={}\\n y_train.shape ={}\\n X_test.shape={}\\n y_test.shape={}'.format(\n",
    "    x_train.shape, y_train.shape, x_test.shape, y_test.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "31830318",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression()"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train a logistic-regression classifier with scikit-learn's default settings.\n",
    "# fit() returns the estimator itself, so construction and training are chained and\n",
    "# the fitted model is then echoed as the cell output.\n",
    "lr_model = LogisticRegression().fit(x_train, y_train)\n",
    "lr_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "dda4ad1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 7.45220115,  1.40225101, -8.85445216])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Intercept (bias) terms of the fitted model; the saved output shows one value per class.\n",
    "bias = lr_model.intercept_\n",
    "bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "53dccad5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.31283774,  0.7144161 , -1.92426567, -0.80157387],\n",
       "       [ 0.52046943, -0.41584669, -0.1519451 , -0.619048  ],\n",
       "       [-0.20763169, -0.2985694 ,  2.07621077,  1.42062187]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Coefficient matrix of the fitted model; the saved output has one row per class\n",
    "# and one column per input feature.\n",
    "weight = lr_model.coef_\n",
    "weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "e6a3b7ae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,\n",
       "       2, 2, 2, 2, 2, 1, 2, 2, 2, 2], dtype=int64)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted class codes for the test samples, displayed for comparison with y_test.\n",
    "y_pred = lr_model.predict(x_test)\n",
    "\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "19604963",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.90625"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy of the fitted classifier on the test split.\n",
    "score = lr_model.score(x_test, y_test)\n",
    "score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75f4c9d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "ee79d42d",
   "metadata": {},
   "outputs": [
    {
     "ename": "MemoryError",
     "evalue": "Unable to allocate 16.2 GiB for an array with shape (171, 220, 341, 170) and data type float64",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mMemoryError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[1;32mC:\\Users\\LIANGZ~1\\AppData\\Local\\Temp/ipykernel_14760/3338040208.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[0mx4_min\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx4_max\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[1;36m.5\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx_train\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m.5\u001b[0m \u001b[1;31m# 第3列的范围'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m \u001b[0mh\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m.02\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mx1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx4\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmeshgrid\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx1_min\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx1_max\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx2_min\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx2_max\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx3_min\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mx3_max\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx4_min\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx4_max\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# 生成网格采样点\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[0mgrid_test\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx3\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx4\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflat\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# 测试点\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mmeshgrid\u001b[1;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[1;32mF:\\anaconda3\\envs\\machine-learning\\lib\\site-packages\\numpy\\lib\\function_base.py\u001b[0m in \u001b[0;36mmeshgrid\u001b[1;34m(copy, sparse, indexing, *xi)\u001b[0m\n\u001b[0;32m   4371\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   4372\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 4373\u001b[1;33m         \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   4374\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   4375\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mF:\\anaconda3\\envs\\machine-learning\\lib\\site-packages\\numpy\\lib\\function_base.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m   4371\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   4372\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 4373\u001b[1;33m         \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   4374\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   4375\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mMemoryError\u001b[0m: Unable to allocate 16.2 GiB for an array with shape (171, 220, 341, 170) and data type float64"
     ]
    }
   ],
   "source": [
    "# Plot the decision boundary. For that, we will assign a color to each\n",
    "# point in the mesh [x_min, x_max]x[y_min, y_max].\n",
    "x1_min, x1_max = x_train[:, 0].min() - .5, x_train[:, 0].max() + .5 # 第0列的范围\n",
    "x2_min, x2_max = x_train[:, 1].min() - .5, x_train[:, 1].max() + .5 # 第1列的范围\n",
    "x3_min, x3_max = x_train[:, 2].min() - .5, x_train[:, 2].max() + .5 # 第2列的范围\n",
    "x4_min, x4_max = x_train[:, 3].min() - .5, x_train[:, 3].max() + .5 # 第3列的范围'\n",
    "h = .02\n",
    "x1, x2, x3, x4 = np.meshgrid(np.arange(x1_min, x1_max, h), np.arange(x2_min, x2_max, h), np.arange(x3_min, x3_max, h), np.arange(x4_min, x4_max, h)) # 生成网格采样点\n",
    "\n",
    "grid_test = np.stack((x1.flat, x2.flat, x3.flat, x4.flat), axis=1)  # 测试点\n",
    "grid_hat = lr_model.predict(grid_test)                  # 预测分类值\n",
    "# grid_hat = lr.predict(np.c_[x1.ravel(), x2.ravel()])\n",
    "grid_hat = grid_hat.reshape(x1.shape)             # 使之与输入的形状相同"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ee79b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(1, figsize=(6, 5))\n",
    "# 预测值的显示, 输出为三个颜色区块，分布表示分类的三类区域\n",
    "plt.pcolormesh(x1, x2, x3, x4, grid_hat,cmap=plt.cm.Paired) \n",
    "\n",
    "# plt.scatter(x_train[:, 0], x_train[:, 1], c=Y,edgecolors='k', cmap=plt.cm.Paired)\n",
    "plt.scatter(x_train[:20, 0], x_train[:20, 1], marker = '*', edgecolors='red', label='setosa')\n",
    "plt.scatter(x_train[20:40, 0], x_train[20:40, 1], marker = '+', edgecolors='k', label='versicolor')\n",
    "plt.scatter(x_train[40:60, 0], x_train[40:60, 1], marker = 'o', edgecolors='k', label='virginica')\n",
    "plt.xlabel('花萼长度-Sepal length')\n",
    "plt.ylabel('花萼宽度-Sepal width')\n",
    "plt.legend(loc = 2)\n",
    "\n",
    "plt.xlim(x1.min(), x1.max())\n",
    "plt.ylim(x2.min(), x2.max())\n",
    "plt.title(\"Logistic Regression 鸢尾花分类结果\", fontsize = 15)\n",
    "plt.xticks(())\n",
    "plt.yticks(())\n",
    "plt.grid()\n",
    "\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
